import os
import torch
from torch.utils.data import Dataset
import logging

# Module-level logger, named after this module, used for every dataset
# status/warning/error message below (cleaner than print(); the original
# author's note says this matches the training scripts).
logger = logging.getLogger(__name__)

class MADTokenDataset(Dataset):
    """
    Custom PyTorch Dataset for loading pre-processed MAD data.

    Expected on-disk layout::

        root_dir/
            <class_dir>/
                <base>_img_raw_patches.pt
                <base>_sig_raw_patches.pt

    Every sub-directory of ``root_dir`` is treated as one class. An image
    file and a signal file form a sample when they share the same ``<base>``
    name; files without a partner are skipped (a warning is logged).

    Integer labels are the positions of the class directories after sorting.
    Directories named ``<N>_<Name>`` are ordered by the integer value of
    ``N`` (so ``10_X`` comes after ``2_Y``) and are listed before directories
    without a numeric prefix, which are ordered alphabetically.
    """

    # File-name suffixes that identify the two halves of a sample.
    _IMG_SUFFIX = "_img_raw_patches.pt"
    _SIG_SUFFIX = "_sig_raw_patches.pt"

    def __init__(self, root_dir, num_classes,
                 img_seq_len=256,          # Length of the image patch sequence (e.g., 256 patches)
                 sig_seq_len=2560,         # Length of the signal patch sequence (e.g., 2560 patches)
                 img_patch_flat_dim=768,   # Dimension of a single flattened image patch (e.g., 3*16*16)
                 sig_patch_dim=60,         # Dimension of a single signal patch (e.g., 60 data points)
                 max_samples_per_class=None, usage="train"):
        """
        Index every (image, signal) file pair found under ``root_dir``.

        No tensor is loaded here; files are only read in ``__getitem__``.

        Args:
            root_dir: Directory containing one sub-directory per class. If it
                does not exist or has no sub-directories, a warning is logged
                and the dataset is left empty.
            num_classes: Number of classes the caller expects. Used only as a
                sanity check: a warning is logged when it differs from the
                number of class directories found.
            img_seq_len: Expected first dimension of each image tensor.
            sig_seq_len: Expected first dimension of each signal tensor.
            img_patch_flat_dim: Expected second dimension of each image tensor.
            sig_patch_dim: Expected second dimension of each signal tensor.
            max_samples_per_class: Optional cap on the number of pairs kept
                per class (the first ones in sorted base-name order).
            usage: Free-form tag (e.g. "train") that only appears in log output.
        """
        super().__init__()

        self.file_pairs = []  # Stores tuples of (image_path, signal_path)
        self.labels = []      # Integer label for each entry of file_pairs
        self.usage = usage
        self.class_names = []  # Class names in label order
        self.num_classes = num_classes

        # Store sequence length and patch dimension info for validation
        self.img_seq_len = img_seq_len
        self.sig_seq_len = sig_seq_len
        self.img_patch_flat_dim = img_patch_flat_dim
        self.sig_patch_dim = sig_patch_dim

        if not os.path.isdir(root_dir):
            logger.warning("Dataset directory not found: %s", root_dir)
            return

        # Numeric-aware ordering keeps label == prefix order even with 10+ classes.
        class_dirs = sorted(
            (d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))),
            key=self._class_sort_key,
        )

        if not class_dirs:
            logger.warning("No class directories found in %s. The dataset will be empty.", root_dir)
            return

        if num_classes is not None and len(class_dirs) != num_classes:
            # Labels outside [0, num_classes) typically fail much later, inside the loss.
            logger.warning(
                "Expected %s classes but found %d class directories in %s.",
                num_classes, len(class_dirs), root_dir,
            )

        for label, class_dir_name in enumerate(class_dirs):
            _, display_name = self._split_class_dir(class_dir_name)
            self.class_names.append(display_name)

            pairs, unmatched = self._collect_pairs(os.path.join(root_dir, class_dir_name))
            if unmatched:
                logger.warning(
                    "Skipped %d file(s) without a matching image/signal partner in class directory '%s'.",
                    unmatched, class_dir_name,
                )

            samples_added = 0
            for pair in pairs:
                if max_samples_per_class is not None and samples_added >= max_samples_per_class:
                    break
                self.file_pairs.append(pair)
                self.labels.append(label)  # Assign integer label
                samples_added += 1

        # Log dataset initialization info
        logger.info(
            "Initialized '%s' dataset with %d sample pairs from %d classes.",
            self.usage, len(self.file_pairs), len(self.class_names),
        )
        if self.class_names:
            logger.info("Classes found: %s", self.class_names)

    @staticmethod
    def _split_class_dir(dir_name):
        """Split ``'<N>_<Name>'`` into ``(N, 'Name')``; return ``(None, dir_name)`` without a numeric prefix."""
        prefix, sep, rest = dir_name.partition('_')
        if not (sep and prefix.isdigit()):
            return None, dir_name
        try:
            number = int(prefix)
        except ValueError:
            # str.isdigit() accepts some characters (e.g. superscripts) that int() rejects.
            number = None
        # Fall back to the directory name so a folder like '0_' never yields an empty class name.
        return number, rest or dir_name

    @classmethod
    def _class_sort_key(cls, dir_name):
        """Sort key: numeric-prefixed directories by integer prefix, then the rest by name."""
        number, _ = cls._split_class_dir(dir_name)
        if number is None:
            return (1, 0, dir_name)
        return (0, number, dir_name)

    @classmethod
    def _collect_pairs(cls, class_path):
        """Return ``(pairs, unmatched_count)`` for one class directory, pairs sorted by base name."""
        img_files = {}  # {base_name: path/to/img.pt}
        sig_files = {}  # {base_name: path/to/sig.pt}

        for f_name in os.listdir(class_path):
            full_path = os.path.join(class_path, f_name)
            # Slice the suffix off the end only; str.replace() would also strip
            # an occurrence in the middle of the name and break the pairing.
            if f_name.endswith(cls._IMG_SUFFIX):
                img_files[f_name[:-len(cls._IMG_SUFFIX)]] = full_path
            elif f_name.endswith(cls._SIG_SUFFIX):
                sig_files[f_name[:-len(cls._SIG_SUFFIX)]] = full_path

        matched = sorted(img_files.keys() & sig_files.keys())
        pairs = [(img_files[base], sig_files[base]) for base in matched]
        unmatched = len(img_files) + len(sig_files) - 2 * len(matched)
        return pairs, unmatched

    @staticmethod
    def _check_shape(patches, expected_shape, kind, path):
        """Raise ``ValueError`` when ``patches.shape`` differs from ``expected_shape``."""
        if patches.shape != expected_shape:
            raise ValueError(f"{kind} patches shape mismatch for {os.path.basename(path)}. "
                             f"Got {patches.shape}, expected {expected_shape}")

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.file_pairs)

    def __getitem__(self, idx):
        """
        Loads and returns the sample at the given index.

        Returns:
            ``((img_patches, sig_patches), label)`` where the two tensors are
            loaded onto the CPU and ``label`` is a scalar ``torch.long`` tensor.

        Raises:
            IndexError: If ``idx`` is out of range.
            ValueError: If a loaded tensor does not have the expected shape.
            Exception: Any error raised while loading a file is logged and
                re-raised so that a corrupt sample stops the run.
        """
        img_patch_path, sig_patch_path = self.file_pairs[idx]
        label = self.labels[idx]

        try:
            # NOTE(security): torch.load unpickles the file. Only use this dataset
            # on trusted data, or pass weights_only=True where the installed
            # torch version supports it.
            img_patches = torch.load(img_patch_path, map_location=torch.device('cpu'))
            sig_patches = torch.load(sig_patch_path, map_location=torch.device('cpu'))

            # Validate the shape of each loaded tensor as a sanity check
            self._check_shape(img_patches, (self.img_seq_len, self.img_patch_flat_dim),
                              "Image", img_patch_path)
            self._check_shape(sig_patches, (self.sig_seq_len, self.sig_patch_dim),
                              "Signal", sig_patch_path)

            # Return data as a tuple of (tensors_tuple, label)
            return (img_patches, sig_patches), torch.tensor(label, dtype=torch.long)

        except Exception as e:
            logger.error(
                "Error loading or processing files for index %s (%s, %s): %s",
                idx, img_patch_path, sig_patch_path, e,
            )
            # Bare raise keeps the original traceback intact.
            raise

    def get_class_names(self):
        """
        Returns the list of class names found in the dataset, in label order.
        (e.g., ['Normal', 'Noise', 'Surface', 'Corona', 'Void'])
        """
        return self.class_names